#include "netlib/rpc/protobuf_codec.h"

#include "netlib/net/socket/tcp_connection.h"
#include "rpc.pb.h"

namespace netlib::rpc {

template <typename MSG>
ProtobufCodec<MSG>::ProtobufCodec(ProtobufMessageCallback&& cb) :
    pb_type_(&MSG::default_instance()), msg_callback_(std::move(cb)) {}

template <typename MSG>
// Serializes `msg` into a single frame and sends it on `conn`.
// Frame layout produced here: [size header][serialized payload][checksum],
// built by serializing into the buffer, then AppendCheckSum(), then PrependSize().
// If the message does not fit in the buffer's writable region, or serialization
// fails, an error is logged and nothing is sent.
void ProtobufCodec<MSG>::Send(const net::TcpConnectionPtr& conn, const ProtobufMessage& msg) const {
	net::Buffer buf;
	// ByteSizeLong() walks the whole message; compute it once.
	const auto byte_size = msg.ByteSizeLong();
	// The payload is serialized directly into the buffer's writable region, so it
	// must fit. An assert() alone is compiled out under NDEBUG, after which an
	// oversized message would make HasWritten() advance past the writable region.
	if (byte_size > buf.WritableBytes()) {
		LOG_ERROR << "Message too large to send: " << byte_size;
		return;
	}
	if (!msg.SerializeToArray(buf.GetWriteBegin(), buf.WritableBytes())) {
		LOG_ERROR << "Failed to serialize message of size " << byte_size;
		return;
	}
	buf.HasWritten(byte_size);
	buf.AppendCheckSum();
	buf.PrependSize();
	conn->Send(buf);
}

template <typename MSG>
// Decodes as many complete frames as are available in `buf`. Each decoded
// message is handed to the stored callback.
// Wire frame: [size: kHeaderSizeSize bytes][payload][checksum: kCheckSumSize bytes],
// where `size` counts payload + checksum (but not the header itself).
//   conn: connection the bytes arrived on; forwarded to the callback.
//   buf:  input buffer; fully decoded frames are consumed, and a trailing
//         partial frame is left in place for the next call.
//   time: receive timestamp; forwarded to the callback.
void ProtobufCodec<MSG>::OnMessage(const net::TcpConnectionPtr& conn,
                                   net::Buffer* buf,
                                   Timestamp time) const {
	while (buf->ReadableBytes() >= net::Buffer::kHeaderSizeSize + kMinMessageSize) {
		const auto size = buf->PeekInt32();
		if (size > kMaxMessageSize || size < kMinMessageSize) {
			// NOTE(review): the bad frame is left in the buffer, so every later call
			// stops here again; shutting the connection down is probably the right
			// response, but that API is not visible from this file.
			LOG_ERROR << "Wrong message size: " << size;
			break;
		}
		if (buf->ReadableBytes() < size + kHeaderSizeSize) {
			// Frame not complete yet; wait for more data.
			break;
		}

		// Verify the checksum stored at the end of the frame.
		const bool checksum_ok = buf->CheckCheckSum(kHeaderSizeSize, kHeaderSizeSize + size - kCheckSumSize);
		if (!checksum_ok) {
			// NOTE(review): same as above - the frame stays in the buffer.
			LOG_ERROR << "Checksum error";
			break;
		}

		// Parse the payload (frame minus header and trailing checksum) into a
		// fresh message cloned from the prototype.
		std::unique_ptr<ProtobufMessage> msg(pb_type_->New());
		const bool parsed = msg->ParseFromArray(buf->Peek() + kHeaderSizeSize, size - kCheckSumSize);
		// The frame has been fully examined either way; consume it.
		buf->Retrieve(kHeaderSizeSize + size);
		if (!parsed) {
			// Peer data is untrusted: never assert on it. The framing passed the
			// size and checksum checks, so drop only this frame and keep decoding.
			LOG_ERROR << "Failed to parse protobuf message of size " << size;
			continue;
		}
		msg_callback_(conn, std::move(msg), time);
	}
}

// Explicit instantiation: the ProtobufCodec member templates are defined in this
// .cpp file rather than in protobuf_codec.h, so ProtobufCodec<RpcMessage> is
// instantiated here to emit those definitions for other translation units.
// Instantiation needs RpcMessage to be a complete type at this point, which is
// presumably provided by the rpc.pb.h include above.
// NOTE(review): the forward declaration looks redundant if rpc.pb.h already
// declares netlib::rpc::RpcMessage - confirm before removing.
class RpcMessage;
template class ProtobufCodec<RpcMessage>;

} // namespace netlib::rpc